import config
import vdb

from models_llama2 import model_llama2_local
from models_chatgpt import model_chatgpt_3p5_turbo
from models_qw import model_qianwen_local
from sentence_transformers import SentenceTransformer

class model_embedding(metaclass=config.SingletonMeta):
    """
        Embedding model singleton.

        Scans the configured embedding models and loads the enabled one
        (the last enabled entry wins) into a SentenceTransformer instance.

        Raises:
            Exception: if no embedding model is enabled in config.json.
    """
    def __init__(self) -> None:
        super().__init__()
        self._eb = None
        # Iterate all configured embedding models; the last one with
        # enable == 1 is selected (matches the original selection order).
        for model in config.get_config().models.embedding:
            if model.enable == 1:
                self._eb = model
        if self._eb is None:
            raise Exception("未找到可用的embedding模型,请检查config.json配置文件")
        print(f"正在初始化embedding模型: {self._eb.name}")
        # Load the model weights from the locally configured path.
        self._model = SentenceTransformer(model_name_or_path=self._eb.path)

    def encode(self, content):
        """Return the embedding vector(s) for *content* (str or list of str)."""
        return self._model.encode(content)

class model_llm(metaclass=config.SingletonMeta):
    """
        LLM orchestration singleton.

        Wires the embedding model and the Milvus vector store together and
        routes retrieval-augmented questions to the LLM backend selected by
        a mode string.
    """
    def __init__(self) -> None:
        super().__init__()
        self._embedding = model_embedding()
        self._vdb = vdb.Milvus()

    def get_aviable_mode_names(self):
        """
            Return the mode names of every LLM enabled in config.json.
        """
        return [llm.mode
                for llm in config.get_config().models.llms
                if llm.enable == 1]

    def retrival_inference_answer(self, question, history, llm_mode: str,
                                  top_k: int = 3, max_distance: float = 600):
        """
            Retrieve candidate answers for *question* from the vector store
            and ask the LLM selected by *llm_mode* for a reply.

            :param question: the user's question (embedded for the search).
            :param history: prior conversation turns, embedded in the prompt.
            :param llm_mode: backend selector, matched by substring
                ("本地模式"+"QWEN"/"LLAMA2", or "ChatGPT3.5").
            :param top_k: number of documents fetched from the vector store.
            :param max_distance: hits at or beyond this distance are dropped.
            :return: the model's reply, or "N/A" if no backend matched.
        """
        vdb_answers = self._vdb.search_document(
            [self._embedding.encode(question)], top_k)

        retrival = ""
        for doc in vdb_answers:
            # todo 距离归一化处理 (normalize distances instead of a raw cutoff)
            # Bug fix: the original wrote float(doc["distance"] < 600),
            # converting the *comparison result* to float instead of the
            # value; it only worked through 0.0/1.0 truthiness.
            if float(doc["distance"]) < max_distance:
                retrival += f'问题:{doc["question"]} 答案:{doc["answer"]}\n'
        if not retrival:
            # Covers both "no hits returned" and "every hit filtered out by
            # distance"; the original only handled the empty-result case.
            retrival = "数据库中没有查询到相关的数据"

        user_prompt = f"历史对话: {history}\n\n知识库: {retrival}\n根据知识进行回答用户问题: {question}"

        resp = "N/A"
        if "本地模式" in llm_mode:
            if "QWEN" in llm_mode:
                resp = model_qianwen_local().completion("", user_prompt)
            # Deliberately a second `if`, not `elif`: when the mode string
            # names both backends, LLAMA2 overrides QWEN, as before.
            if "LLAMA2" in llm_mode:
                print(f"use : {llm_mode}")
                llma = model_llama2_local()
                resp = llma(user_prompt)

        if "ChatGPT3.5" in llm_mode:
            resp = model_chatgpt_3p5_turbo().completion("", user_prompt)

        return resp